#!/usr/bin/env python
# pylint: disable=missing-docstring
# -*- coding: utf-8 -*-

"""Evaluate a trained model on KITTI road data and log the metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp
import json
import logging
import numpy as np
import os.path
import sys

import scipy as scp
import scipy.misc

import overlay_utils as overlay

from kitti_devkit import seg_utils as seg

# Configure logging: timestamped records at INFO level, written to stdout.
# NOTE: the TV_IS_DEV environment variable used to select between two
# byte-identical configurations, so a single unconditional call is
# equivalent.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)


import time

from shutil import copyfile

from six.moves import xrange  # pylint: disable=redefined-builtin

import tensorflow as tf

import tensorvision.utils as utils
import tensorvision.core as core
import tensorvision.analyze as ana

# Command-line flags; their values are read through FLAGS in do_inference().
flags = tf.app.flags
FLAGS = flags.FLAGS

# When set, do_inference() additionally runs
# do_kitti_eval_with_training_data() after the validation statistics.
flags.DEFINE_bool('kitti_eval', True,
                  'Do full epoche of Kitti train Evaluation.')

# When set, do_inference() loads the val.json of a previous run and opens an
# IPython shell instead of evaluating.  (The help text used to be a copy of
# the one for 'kitti_eval'.)
flags.DEFINE_bool('inspect', False,
                  'Load val.json of a previous run and open an IPython '
                  'shell instead of evaluating.')


def _load_weights(checkpoint_dir, sess, saver):
    """Restore the checkpoint recorded in `checkpoint_dir` into `sess`.

    The path stored in the checkpoint state is tried first.  If it does not
    exist (presumably because the log directory was moved after training),
    a file with the same basename inside `checkpoint_dir` is tried instead.

    Parameters
    ----------
    checkpoint_dir : string
        Directory that holds the checkpoint state and the weight files.
    sess : session
        Session the weights are restored into.
    saver : saver
        Object whose `restore(sess, path)` performs the restore.

    Raises
    ------
    SystemExit
        If a checkpoint is recorded but neither candidate file exists.
    """
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        # This case used to pass silently; keep going (best effort) but make
        # it visible that nothing was restored.
        logging.warning("No checkpoint found in %s; no weights were restored.",
                        checkpoint_dir)
        return

    recorded_path = ckpt.model_checkpoint_path
    logging.info(recorded_path)
    if os.path.exists(recorded_path):
        saver.restore(sess, recorded_path)
        return

    # Fall back to a file of the same name next to the checkpoint state.
    local_path = os.path.join(checkpoint_dir,
                              os.path.basename(recorded_path))
    if not os.path.exists(local_path):
        logging.error("File not found: %s", recorded_path)
        logging.error("File not found: %s", local_path)
        logging.error("Could not find weights.")
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist in every interpreter setup.
        sys.exit(1)
    saver.restore(sess, local_path)


def _add_softmax(hypes, logits):
    """Attach a softmax over two classes to `logits`.

    The logits are reshaped to two columns, the constant
    hypes['solver']['epsilon'] is added to every entry and the softmax of
    the result is returned.  All ops live in the 'decoder' name scope.
    """
    with tf.name_scope('decoder'):
        two_column_logits = tf.reshape(logits, (-1, 2))
        eps = tf.constant(value=hypes['solver']['epsilon'])
        return tf.nn.softmax(two_column_logits + eps)


def _create_input_placeholder():
    """Return a pair of float32 placeholders: (image, label)."""
    # The tuple elements are evaluated left to right, so the image
    # placeholder is created before the label placeholder.
    return tf.placeholder(tf.float32), tf.placeholder(tf.float32)


def build_inference_graph(hypes, modules, image, label):
    """Build the inference graph and return its softmax output.

    Parameters
    ----------
    hypes : dict
        Hyperparameters.
    modules : tuple
        The four modules (data_input, arch, objective, solver) loaded by
        utils; only `arch` and `objective` are used here.
    image : tensor
        Input handed to `arch.inference`.
    label : placeholder
        Not used by this function.

    Returns
    -------
    The softmax tensor created by `_add_softmax` on top of the decoder
    output.
    """
    # Unpack all four entries so that a tuple of the wrong length still
    # fails here, even though two of them are not needed.
    _, arch, objective, _ = modules

    decoded = objective.decoder(hypes,
                                arch.inference(hypes, image, train=False))
    return _add_softmax(hypes, decoded)


def _prepare_output_folder(hypes, logdir):
    """Make sure `<logdir>/eval` exists and record it in `hypes`.

    The resolved path is stored under hypes['dirs']['eval_out'].  An
    existing directory is reused as it is.
    """
    eval_dir = os.path.join(os.path.realpath(logdir), 'eval')
    already_present = os.path.exists(eval_dir)
    if not already_present:
        os.mkdir(eval_dir)
    hypes['dirs']['eval_out'] = eval_dir


def eval_image(hypes, gt_image, cnn_image):
    """Compare one network output with its ground-truth image.

    Parameters
    ----------
    hypes : dict
        Hyperparameters (not used here).
    gt_image : array
        Ground truth, indexed as [row, column, channel].  Pixels with a
        positive channel 2 are passed to `seg.evalExp` as the road ground
        truth, pixels with a positive channel 0 as the valid area.
    cnn_image : array
        Network output handed to `seg.evalExp` as the prediction.

    Returns
    -------
    tuple
        (FN, FP, posNum, negNum) as returned by `seg.evalExp` for 256
        evenly spaced thresholds between 0 and 1.
    """
    thresholds = np.arange(256) / 255.0
    road_mask = gt_image[:, :, 2] > 0
    valid_mask = gt_image[:, :, 0] > 0

    false_neg, false_pos, num_pos, num_neg = seg.evalExp(
        road_mask, cnn_image, thresholds,
        validMap=None, validArea=valid_mask)

    return false_neg, false_pos, num_pos, num_neg

import random


def tensor_eval(hypes, sess, image_pl, softmax):
    """Evaluate the validation list and return metrics plus sample overlays.

    The list file is hypes['data']['val_file'] inside
    hypes['dirs']['data_dir'].  Each non-empty line holds an image path and
    a ground-truth path separated by a single space, both relative to the
    directory of the list file.

    Parameters
    ----------
    hypes : dict
        Hyperparameters.
    sess : session
        Session used to run `softmax`.
    image_pl : placeholder
        Placeholder the loaded image is fed into.
    softmax : tensor
        Network output; column 1 is reshaped to the image height and width
        and used as the prediction.

    Returns
    -------
    eval_list : list of (name, value) tuples
        'MaxF1', 'BestThresh' and 'Average Precision', each multiplied by
        100.
    images : list
        Overlays of a random subset of the images (each image is kept with
        probability 0.1).
    """
    data_file = os.path.join(hypes['dirs']['data_dir'],
                             hypes['data']['val_file'])
    image_dir = os.path.dirname(data_file)

    thresh = np.array(range(0, 256))/255.0
    total_fp = np.zeros(thresh.shape)
    total_fn = np.zeros(thresh.shape)
    total_posnum = 0
    total_negnum = 0

    images = []

    with open(data_file) as list_file:
        for datum in list_file:
            datum = datum.rstrip()
            if not datum:
                # A blank line (e.g. at the end of the list) would
                # otherwise break the two-value unpacking below.
                continue
            image_file, gt_file = datum.split(" ")
            image_file = os.path.join(image_dir, image_file)
            gt_file = os.path.join(image_dir, gt_file)

            image = scp.misc.imread(image_file)
            gt_image = scp.misc.imread(gt_file)
            shape = image.shape

            feed_dict = {image_pl: image}

            # Time only the forward pass, reported in milliseconds.
            start_time = time.time()
            output = sess.run([softmax], feed_dict=feed_dict)
            duration = float(time.time() - start_time)*1000
            output_im = output[0][:, 1].reshape(shape[0], shape[1])
            print('Duration %.3f ms' % (duration))

            if random.random() < 0.1:
                images.append(seg.make_overlay(image, output_im))

            FN, FP, posNum, negNum = eval_image(hypes, gt_image, output_im)

            total_fp += FP
            total_fn += FN
            total_posnum += posNum
            total_negnum += negNum

    eval_dict = seg.pxEval_maximizeFMeasure(total_posnum, total_negnum,
                                            total_fn, total_fp,
                                            thresh=thresh)

    eval_list = [
        ('MaxF1', 100*eval_dict['MaxF']),
        ('BestThresh', 100*eval_dict['BestThresh']),
        ('Average Precision', 100*eval_dict['AvgPrec']),
    ]

    return eval_list, images


def eval_dataset(hypes, data_file, save_overlay,
                 sess, image_pl, softmax):
    """Evaluate every image listed in `data_file`.

    Each non-empty line of `data_file` holds an image path and a
    ground-truth path separated by a single space, both relative to the
    directory of `data_file`.

    Parameters
    ----------
    hypes : dict
        Hyperparameters; hypes['dirs']['eval_out'] is read when
        `save_overlay` is set.
    data_file : string
        Path of the list file.
    save_overlay : bool
        If true, an overlay '<image name>_ov.png' is saved per image into a
        sub-directory of hypes['dirs']['eval_out'] named after the list
        file.
    sess : session
        Session used to run `softmax`.
    image_pl : placeholder
        Placeholder the loaded image is fed into.
    softmax : tensor
        Network output; column 1 is reshaped to the image height and width
        and used as the prediction.

    Returns
    -------
    The result of `seg.pxEval_maximizeFMeasure` on the counts accumulated
    over all images.
    """
    image_dir = os.path.dirname(data_file)

    im_name = os.path.basename(data_file).split(".")[0]
    if save_overlay:
        eval_out = os.path.join(hypes['dirs']['eval_out'], im_name)
        if not os.path.exists(eval_out):
            os.mkdir(eval_out)

    thresh = np.array(range(0, 256))/255.0
    total_fp = np.zeros(thresh.shape)
    total_fn = np.zeros(thresh.shape)
    total_posnum = 0
    total_negnum = 0

    with open(data_file) as list_file:
        for datum in list_file:
            datum = datum.rstrip()
            if not datum:
                # A blank line (e.g. at the end of the list) would
                # otherwise break the two-value unpacking below.
                continue
            image_file, gt_file = datum.split(" ")
            image_file = os.path.join(image_dir, image_file)
            gt_file = os.path.join(image_dir, gt_file)

            image = scp.misc.imread(image_file)
            gt_image = scp.misc.imread(gt_file)
            shape = image.shape

            feed_dict = {image_pl: image}
            # Time only the forward pass, reported in milliseconds.
            start_time = time.time()
            output = sess.run([softmax], feed_dict=feed_dict)
            duration = float(time.time() - start_time)*1000
            print('Duration %.3f ms' % (duration))
            output_im = output[0][:, 1].reshape(shape[0], shape[1])

            if save_overlay:
                ov_image = seg.make_overlay(image, output_im)
                image_name = os.path.basename(image_file).split('.')[0]
                ov_name = os.path.join(eval_out, image_name + "_ov.png")
                scp.misc.imsave(ov_name, ov_image)

            FN, FP, posNum, negNum = eval_image(hypes, gt_image, output_im)

            total_fp += FP
            total_fn += FN
            total_posnum += posNum
            total_negnum += negNum

    return seg.pxEval_maximizeFMeasure(total_posnum, total_negnum,
                                       total_fn, total_fp,
                                       thresh=thresh)


def _get_filewrite_handler(logging_file):
    """Create a handler that logs INFO and above to `logging_file`.

    The file is opened in 'w' mode, so existing content is overwritten.
    """
    handler = logging.FileHandler(logging_file, mode='w')
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s %(name)-3s %(levelname)-3s %(message)s'))
    return handler


def do_kitti_eval_with_training_data(hypes, sess, image_pl, softmax):
    """Evaluate the model on the KITTI list files and log the metrics.

    For each of 'val3.txt', 'um.txt', 'umm.txt' and 'uu.txt' inside
    `<data_dir>/data_road`, `eval_dataset` is run without saving overlays
    and the selected metrics are logged as percentages.  While this runs,
    log records are additionally written to 'kitti.log' in
    hypes['dirs']['eval_out'].

    Parameters
    ----------
    hypes : dict
        Hyperparameters; reads hypes['dirs']['eval_out'] and
        hypes['dirs']['data_dir'].
    sess, image_pl, softmax
        Passed through to `eval_dataset`.
    """
    show_dict = ['MaxF', 'BestThresh', 'AvgPrec', 'PRE_wp', 'REC_wp',
                 'FPR_wp', 'FNR_wp']
    eval_files = ['val3.txt', 'um.txt', 'umm.txt', 'uu.txt']
    eval_names = ['Validation Data', 'Urban Marked',
                  'Urban Multi Marked', 'Urban Unmarked']
    factor = 100  # metrics are reported as percentages

    logging_file = os.path.join(hypes['dirs']['eval_out'], 'kitti.log')
    filewriter = _get_filewrite_handler(logging_file)
    rootlog = logging.getLogger('')
    rootlog.addHandler(filewriter)

    try:
        for list_name, name in zip(eval_files, eval_names):
            data_file = os.path.join('data_road', list_name)
            data_file = os.path.join(hypes['dirs']['data_dir'], data_file)
            data_file = os.path.realpath(data_file)

            eval_val_dict = eval_dataset(hypes, data_file, False,
                                         sess, image_pl, softmax)

            logging.info("Results for %s Data.", name)

            for metric in show_dict:
                logging.info('%s: %4.2f', metric,
                             eval_val_dict[metric]*factor)
    finally:
        # Detach and close the handler even if an evaluation fails, so the
        # root logger is left as it was and the log file is released.
        rootlog.removeHandler(filewriter)
        filewriter.close()


def _json_default(obj):
    """`json` default hook: turn numpy arrays and scalars into plain Python."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError("Object of type %s is not JSON serializable"
                    % type(obj).__name__)


def do_inference(hypes, modules, logdir):
    """
    Analyze a trained model.

    This will load model files and weights found in logdir and run a basic
    analysis: the validation list is evaluated (saving overlays), the
    results are written to `<logdir>/eval/val.json`, statistics are logged
    to stdout and `<logdir>/eval/eval.log`, and, if the 'kitti_eval' flag is
    set, the KITTI list files are evaluated as well.

    With the 'inspect' flag set, nothing is evaluated; instead the val.json
    of a previous run is loaded and an IPython shell is opened.

    Parameters
    ----------
    hypes : dict
        Hyperparameters.  hypes['dirs'] is updated in place.
    modules : tuple
        Modules passed on to `build_inference_graph`.
    logdir : string
        folder with logs

    Raises
    ------
    SystemExit
        In inspect mode (always), or when weights cannot be found.
    """
    if 'TV_DIR_DATA' in os.environ:
        hypes['dirs']['data_dir'] = os.environ['TV_DIR_DATA']
        # NOTE(review): output_dir is only overridden together with the
        # data dir -- confirm this coupling is intended.
        hypes['dirs']['output_dir'] = logdir

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():

        image_pl, label_pl = _create_input_placeholder()

        image = tf.expand_dims(image_pl, 0)

        # Whitening is on unless it is explicitly disabled in the hypes.
        if 'whitening' not in hypes['arch'] or \
                hypes['arch']['whitening']:
            image = tf.image.per_image_whitening(image)
            logging.info('Whitening is enabled.')
        else:
            logging.info('Whitening is disabled.')

        # build the graph based on the loaded modules
        softmax = build_inference_graph(hypes, modules,
                                        image=image, label=label_pl)

        # prepare the tv session
        sess_coll = core.start_tv_session(hypes)
        sess, saver, summary_op, summary_writer, coord, threads = sess_coll

        _load_weights(logdir, sess, saver)

    _prepare_output_folder(hypes, logdir)

    val_json = os.path.join(hypes['dirs']['eval_out'], 'val.json')

    if FLAGS.inspect:
        if not os.path.exists(val_json):
            logging.error("File does not exist: %s", val_json)
            logging.error("Please run kitti_eval in normal mode first.")
            sys.exit(1)
        with open(val_json, 'r') as f:
            eval_dict = json.load(f)
        logging.debug(eval_dict)
        # Imported lazily so IPython is only needed in inspect mode.
        from IPython import embed
        embed()
        sys.exit(0)

    logging.info("Doing evaluation with Validation Data")
    val_file = os.path.join(hypes['dirs']['data_dir'],
                            hypes['data']['val_file'])
    eval_dict = eval_dataset(hypes, val_file, True, sess, image_pl, softmax)

    # Serialize before touching the file: previously val.json was opened for
    # writing but left empty, which made a later --inspect run fail while
    # parsing it.  If the results cannot be serialized, no file is written.
    try:
        serialized = json.dumps(eval_dict, indent=2, default=_json_default)
    except (TypeError, ValueError) as err:
        logging.warning("Could not serialize evaluation results: %s", err)
    else:
        with open(val_json, 'w') as outfile:
            outfile.write(serialized)
        logging.info("Succesfully evaluated Dataset. Output is written to %s",
                     val_json)

    logging_file = os.path.join(hypes['dirs']['eval_out'], 'eval.log')
    filewriter = _get_filewrite_handler(logging_file)
    rootlog = logging.getLogger('')
    rootlog.addHandler(filewriter)

    try:
        logging.info('Statistics on Validation Data.')

        logging.info('MaxF1          : %4.2f', 100*eval_dict['MaxF'])
        logging.info('BestThresh     : %4.2f', 100*eval_dict['BestThresh'])
        logging.info('Avg Precision  : %4.2f', 100*eval_dict['AvgPrec'])
        logging.info('')
        # Index of the first threshold that is >= 0.5.
        ind5 = np.where(eval_dict['thresh'] >= 0.5)[0][0]
        logging.info('Precision @ 0.5: %4.2f',
                     100*eval_dict['precision'][ind5])
        logging.info('Recall    @ 0.5: %4.2f', 100*eval_dict['recall'][ind5])
        # The true positive rate is the recall, hence the same value.
        logging.info('TPR       @ 0.5: %4.2f', 100*eval_dict['recall'][ind5])
        logging.info('TNR       @ 0.5: %4.2f', 100*eval_dict['TNR'][ind5])

        if FLAGS.kitti_eval:
            do_kitti_eval_with_training_data(hypes, sess, image_pl, softmax)
    finally:
        # Detach and close the handler even if something above fails, so
        # the root logger is left as it was and the log file is released.
        rootlog.removeHandler(filewriter)
        filewriter.close()

    # NOTE(review): this uses the global flag rather than the `logdir`
    # argument; the two only agree when called from main() -- confirm.
    ana.do_analyze(FLAGS.logdir)


def main(_):
    """Check the flags, select the visible GPUs and analyze the model.

    The GPU ids come from the 'gpus' flag if it is given, otherwise from
    the TV_USE_GPUS environment variable.  If that variable is 'force', a
    flag is mandatory and the program exits.
    """
    if FLAGS.logdir is None:
        logging.error("No logdir are given.")
        logging.error("Usage: tv-analyze --logdir dir")
        exit(1)

    flag_gpus = FLAGS.gpus
    if flag_gpus is not None:
        logging.info("GPUs are set to: %s", flag_gpus)
        os.environ['CUDA_VISIBLE_DEVICES'] = flag_gpus
    elif 'TV_USE_GPUS' in os.environ:
        env_gpus = os.environ['TV_USE_GPUS']
        if env_gpus == 'force':
            logging.error('Please specify a GPU.')
            logging.error('Usage tv-train --gpus <ids>')
            exit(1)
        logging.info("GPUs are set to: %s", env_gpus)
        os.environ['CUDA_VISIBLE_DEVICES'] = env_gpus

    utils.load_plugins()

    logdir = os.path.realpath(FLAGS.logdir)
    hypes = utils.load_hypes_from_logdir(logdir)
    modules = utils.load_modules_from_logdir(logdir)

    logging.info("Starting to analyze Model in: %s", logdir)
    do_inference(hypes, modules, logdir)


if __name__ == '__main__':
    # Only run when executed as a script.  tf.app.run() is expected to parse
    # the command-line flags and then call main() -- see the TensorFlow docs.
    tf.app.run()
